{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9b2bc6c7-f54c-436f-ab66-86a631fb75d8",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddfe7d95-3567-4cb2-9eb5-65f235113768",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fab1bcae-678b-4b19-a513-d0577d3d7e2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite,pyyaml]\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8ae8b11-f5cf-4f91-ac60-8660f2ab2a4d",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1492c89-b19f-4216-b3a0-9960397e72ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from monai.apps import MedNISTDataset\n",
    "from monai.config import print_config\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2682936a-09ed-4703-af06-c59f755395ee",
   "metadata": {},
   "source": [
    "# MedNIST Classification Bundle\n",
    "\n",
    "In this tutorial we'll revisit the bundle replicating the [MONAI 101 notebook](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) and add more features that represent best practices, including evaluation during training and checkpoint saving.\n",
    "\n",
    "We'll first create a bundle very much like in the previous tutorial with the same metadata and common script file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb9dc6ec-13da-4a37-8afa-28e2766b9343",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/usr/bin/tree\n",
      "\u001b[01;34mMedNISTClassifier_v2\u001b[00m\n",
      "├── \u001b[01;34mconfigs\u001b[00m\n",
      "│   ├── inference.json\n",
      "│   └── metadata.json\n",
      "├── \u001b[01;34mdocs\u001b[00m\n",
      "│   └── README.md\n",
      "├── LICENSE\n",
      "└── \u001b[01;34mmodels\u001b[00m\n",
      "\n",
      "3 directories, 4 files\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "python -m monai.bundle init_bundle MedNISTClassifier_v2\n",
    "which tree && tree MedNISTClassifier_v2 || true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b29f053b-cf16-4ffc-bbe7-d9433fdfa872",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting MedNISTClassifier_v2/configs/metadata.json\n"
     ]
    }
   ],
   "source": [
    "%%writefile MedNISTClassifier_v2/configs/metadata.json\n",
    "\n",
    "{\n",
    "    \"version\": \"0.0.1\",\n",
    "    \"changelog\": {\n",
    "        \"0.0.1\": \"Initial version\"\n",
    "    },\n",
    "    \"monai_version\": \"1.2.0\",\n",
    "    \"pytorch_version\": \"2.0.0\",\n",
    "    \"numpy_version\": \"1.23.5\",\n",
    "    \"required_packages_version\": {},\n",
    "    \"name\": \"MedNISTClassifier\",\n",
    "    \"task\": \"MedNIST Classification Network\",\n",
    "    \"description\": \"This is a demo network for classifying MedNIST images by type/modality\",\n",
    "    \"authors\": \"Your Name Here\",\n",
    "    \"copyright\": \"Copyright (c) Your Name Here\",\n",
    "    \"data_source\": \"MedNIST dataset kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic)\",\n",
    "    \"data_type\": \"jpeg\",\n",
    "    \"intended_use\": \"This is suitable for demonstration only\",\n",
    "    \"network_data_format\": {\n",
    "        \"inputs\": {\n",
    "            \"image\": {\n",
    "                \"type\": \"image\",\n",
    "                \"format\": \"magnitude\",\n",
    "                \"modality\": \"any\",\n",
    "                \"num_channels\": 1,\n",
    "                \"spatial_shape\": [64, 64],\n",
    "                \"dtype\": \"float32\",\n",
    "                \"value_range\": [0, 1],\n",
    "                \"is_patch_data\": false,\n",
    "                \"channel_def\": {\n",
    "                    \"0\": \"image\"\n",
    "                }\n",
    "            }\n",
    "        },\n",
    "        \"outputs\": {\n",
    "            \"pred\": {\n",
    "                \"type\": \"probabilities\",\n",
    "                \"format\": \"classes\",\n",
    "                \"num_channels\": 6,\n",
    "                \"spatial_shape\": [6],\n",
    "                \"dtype\": \"float32\",\n",
    "                \"value_range\": [0, 1],\n",
    "                \"is_patch_data\": false,\n",
    "                \"channel_def\": {\n",
    "                    \"0\": \"AbdomenCT\",\n",
    "                    \"1\": \"BreastMRI\",\n",
    "                    \"2\": \"CXR\",\n",
    "                    \"3\": \"ChestCT\",\n",
    "                    \"4\": \"Hand\",\n",
    "                    \"5\": \"HeadCT\"\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04826c73-7c26-4c5e-8d2a-8968c3954b5a",
   "metadata": {},
   "source": [
    "As you've likely seen in earlier output, a bundle should have a `logging.conf` file in its `configs` directory to set up the Python logger appropriately. Adding one will improve the output we get in the notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0cb1b023-d192-4ad7-b2eb-c4a2c6b42b84",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing MedNISTClassifier_v2/configs/logging.conf\n"
     ]
    }
   ],
   "source": [
    "%%writefile MedNISTClassifier_v2/configs/logging.conf\n",
    "\n",
    "[loggers]\n",
    "keys=root\n",
    "\n",
    "[handlers]\n",
    "keys=consoleHandler\n",
    "\n",
    "[formatters]\n",
    "keys=fullFormatter\n",
    "\n",
    "[logger_root]\n",
    "level=INFO\n",
    "handlers=consoleHandler\n",
    "\n",
    "[handler_consoleHandler]\n",
    "class=StreamHandler\n",
    "level=INFO\n",
    "formatter=fullFormatter\n",
    "args=(sys.stdout,)\n",
    "\n",
    "[formatter_fullFormatter]\n",
    "format=%(asctime)s - %(name)s - %(levelname)s - %(message)s\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b306ff33-c39b-4822-b6d4-346987cfe87b",
   "metadata": {},
   "source": [
    "We'll change the common file slightly by adding some extra symbols, specifically `bundle_root`, which should always be present in bundles. We'll keep `root_dir` since it determines where MedNIST is downloaded to."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d11681af-3210-4b2b-b7bd-8ad8dedfe230",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing MedNISTClassifier_v2/configs/common.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile MedNISTClassifier_v2/configs/common.yaml\n",
    "\n",
    "# added a few more imports\n",
    "imports: \n",
    "- $import torch\n",
    "- $import datetime\n",
    "- $import os\n",
    "\n",
    "root_dir: .\n",
    "\n",
    "# use constants from MONAI instead of hard-coding names\n",
    "image: $monai.utils.CommonKeys.IMAGE\n",
    "label: $monai.utils.CommonKeys.LABEL\n",
    "pred: $monai.utils.CommonKeys.PRED\n",
    "\n",
    "# these are added definitions\n",
    "bundle_root: .\n",
    "ckpt_path: $@bundle_root + '/models/model.pt'\n",
    "\n",
    "# define a device for the network\n",
    "device: '$torch.device(''cuda:0'')'\n",
    "\n",
    "# store the class names for inference later\n",
    "class_names: [AbdomenCT, BreastMRI, CXR, ChestCT, Hand, HeadCT]\n",
    "\n",
    "# define the network separately, don't need to refer to MONAI types by name or import MONAI\n",
    "network_def:\n",
    "  _target_: densenet121\n",
    "  spatial_dims: 2\n",
    "  in_channels: 1\n",
    "  out_channels: 6\n",
    "\n",
    "# define the network to be the given definition moved to the device\n",
    "net: '$@network_def.to(@device)'\n",
    "\n",
    "# define a transform sequence as a list of transform objects instead of using Compose here\n",
    "train_transforms:\n",
    "- _target_: LoadImaged\n",
    "  keys: '@image'\n",
    "  image_only: true\n",
    "- _target_: EnsureChannelFirstd\n",
    "  keys: '@image'\n",
    "- _target_: ScaleIntensityd\n",
    "  keys: '@image'\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eaf81ea7-9ea3-4548-a32e-992f0b9bc0ab",
   "metadata": {},
   "source": [
    "\n",
    "## Training\n",
    "\n",
    "For training we have the same elements again, but we'll add a `SupervisedEvaluator` object to track model progress, along with handlers to run validation and save checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4dfd052e-abe7-473a-bbf4-25674a3b20ea",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing MedNISTClassifier_v2/configs/train.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile MedNISTClassifier_v2/configs/train.yaml\n",
    "\n",
    "max_epochs: 25\n",
    "learning_rate: 0.00001  # learning rate, again artificially slow\n",
    "val_interval: 1  # run validation every n'th epoch\n",
    "save_interval: 1  # save the model weights every n'th epoch\n",
    "\n",
    "# choose a unique output subdirectory every time training is started, \n",
    "output_dir: '$datetime.datetime.now().strftime(@root_dir+''/output/output_%y%m%d_%H%M%S'')'\n",
    "\n",
    "train_dataset:\n",
    "  _target_: MedNISTDataset\n",
    "  root_dir: '@root_dir'\n",
    "  transform: \n",
    "    _target_: Compose\n",
    "    transforms: '@train_transforms'\n",
    "  section: training\n",
    "  download: true\n",
    "\n",
    "train_dl:\n",
    "  _target_: DataLoader\n",
    "  dataset: '@train_dataset'\n",
    "  batch_size: 512\n",
    "  shuffle: true\n",
    "  num_workers: 4\n",
    "\n",
    "# separate dataset taking from the \"validation\" section\n",
    "eval_dataset:\n",
    "  _target_: MedNISTDataset\n",
    "  root_dir: '@root_dir'\n",
    "  transform: \n",
    "    _target_: Compose\n",
    "    transforms: '$@train_transforms'\n",
    "  section: validation\n",
    "  download: true\n",
    "\n",
    "# separate dataloader for evaluation\n",
    "eval_dl:\n",
    "  _target_: DataLoader\n",
    "  dataset: '@eval_dataset'\n",
    "  batch_size: 512\n",
    "  shuffle: false\n",
    "  num_workers: 4\n",
    "\n",
    "# transforms applied to network output, in this case applying activation, argmax, and one-hot-encoding\n",
    "post_transform:\n",
    "  _target_: Compose\n",
    "  transforms:\n",
    "  - _target_: Activationsd\n",
    "    keys: '@pred'\n",
    "    softmax: true  # apply softmax to the prediction to emphasize the most likely value\n",
    "  - _target_: AsDiscreted\n",
    "    keys: ['@label','@pred']\n",
    "    argmax: [false, true]  # apply argmax to the prediction only to get a class index number\n",
    "    to_onehot: 6  # convert both prediction and label to one-hot format so that both have shape (6,)\n",
    "\n",
    "# separating out loss, inferer, and optimizer definitions\n",
    "\n",
    "loss_function:\n",
    "  _target_: torch.nn.CrossEntropyLoss\n",
    "\n",
    "inferer: \n",
    "  _target_: SimpleInferer\n",
    "\n",
    "optimizer: \n",
    "  _target_: torch.optim.Adam\n",
    "  params: '$@net.parameters()'\n",
    "  lr: '@learning_rate'\n",
    "\n",
    "# Handlers to load the checkpoint if present, run validation at the chosen interval, save the checkpoint\n",
    "# at the chosen interval, log stats, and write the log to a file in the output directory.\n",
    "handlers:\n",
    "- _target_: CheckpointLoader\n",
    "  _disabled_: '$not os.path.exists(@ckpt_path)'\n",
    "  load_path: '@ckpt_path'\n",
    "  load_dict:\n",
    "    model: '@net'\n",
    "- _target_: ValidationHandler\n",
    "  validator: '@evaluator'\n",
    "  epoch_level: true\n",
    "  interval: '@val_interval'\n",
    "- _target_: CheckpointSaver\n",
    "  save_dir: '@output_dir'\n",
    "  save_dict:\n",
    "    model: '@net'\n",
    "  save_interval: '@save_interval'\n",
    "  save_final: true  # save the final weights, either when the run finishes or is interrupted somehow\n",
    "- _target_: StatsHandler\n",
    "  name: train_loss\n",
    "  tag_name: train_loss\n",
    "  output_transform: '$monai.handlers.from_engine([''loss''], first=True)'  # print per-iteration loss\n",
    "- _target_: LogfileHandler\n",
    "  output_dir: '@output_dir'\n",
    "\n",
    "trainer:\n",
    "  _target_: SupervisedTrainer\n",
    "  device: '@device'\n",
    "  max_epochs: '@max_epochs'\n",
    "  train_data_loader: '@train_dl'\n",
    "  network: '@net'\n",
    "  optimizer: '@optimizer'\n",
    "  loss_function: '@loss_function'\n",
    "  inferer: '@inferer'\n",
    "  train_handlers: '@handlers'\n",
    "\n",
    "# validation handlers which log stats and direct the log to a file\n",
    "val_handlers:\n",
    "- _target_: StatsHandler\n",
    "  name: val_stats\n",
    "  output_transform: '$lambda x: None'\n",
    "- _target_: LogfileHandler\n",
    "  output_dir: '@output_dir'\n",
    "\n",
    "# Metrics to assess validation results, you can have more than one here but may \n",
    "# need to adapt the format of pred and label.\n",
    "metrics:\n",
    "  accuracy:\n",
    "    _target_: 'ignite.metrics.Accuracy'\n",
    "    output_transform: '$monai.handlers.from_engine([@pred, @label])'\n",
    "\n",
    "# runs the evaluation process, invoked by trainer via the ValidationHandler object\n",
    "evaluator:\n",
    "  _target_: SupervisedEvaluator\n",
    "  device: '@device'\n",
    "  val_data_loader: '@eval_dl'\n",
    "  network: '@net'\n",
    "  inferer: '@inferer'\n",
    "  postprocessing: '@post_transform'\n",
    "  key_val_metric: '@metrics'\n",
    "  val_handlers: '@val_handlers'\n",
    "\n",
    "train:\n",
    "- '$@trainer.run()'\n"
   ]
  },
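  {
   "cell_type": "markdown",
   "id": "a3c1e5f7-2b4d-4c6e-8f01-5d7b9c1e3a01",
   "metadata": {},
   "source": [
    "To make the `post_transform` steps in the config above concrete, here is a minimal pure-Python sketch of what softmax, argmax, and one-hot encoding do to a single six-class prediction. It is an illustration only and does not call MONAI; the `logits` values and the helper names `softmax` and `one_hot` are made up for this example.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def softmax(values):\n",
    "    # subtract the max for numerical stability before exponentiating\n",
    "    m = max(values)\n",
    "    exps = [math.exp(v - m) for v in values]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "def one_hot(index, num_classes):\n",
    "    # vector of zeros with a single 1.0 at the class index\n",
    "    return [1.0 if i == index else 0.0 for i in range(num_classes)]\n",
    "\n",
    "\n",
    "logits = [0.1, 2.5, -1.0, 0.3, 0.0, 1.2]  # hypothetical raw network output\n",
    "probs = softmax(logits)  # the effect of `softmax: true`\n",
    "pred_index = probs.index(max(probs))  # the effect of `argmax: true`\n",
    "pred_one_hot = one_hot(pred_index, 6)  # the effect of `to_onehot: 6`\n",
    "\n",
    "print(pred_index, pred_one_hot)\n",
    "```\n",
    "\n",
    "In the config the label skips the argmax step (its `argmax` entry is `false`) and is only one-hot encoded, so prediction and label both end up with shape `(6,)`."
   ]
  },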
  {
   "cell_type": "markdown",
   "id": "de752181-80b1-4221-9e4a-315e5f7f22a6",
   "metadata": {},
   "source": [
    "We can now train as normal, specifying the logging config file and a maximum number of epochs, which you will probably want to set higher for a good result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8357670d-fe69-4789-9b9a-77c0d8144b10",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "\n",
    "BUNDLE=\"./MedNISTClassifier_v2\"\n",
    "\n",
    "python -m monai.bundle run train \\\n",
    "    --bundle_root \"$BUNDLE\" \\\n",
    "    --logging_file \"$BUNDLE/configs/logging.conf\" \\\n",
    "    --meta_file \"$BUNDLE/configs/metadata.json\" \\\n",
    "    --config_file \"['$BUNDLE/configs/common.yaml','$BUNDLE/configs/train.yaml']\" \\\n",
    "    --max_epochs 2 &> out.txt || true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3d7e7e11-db67-47e3-a03d-0955feee1636",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# open(\"out.txt\").read()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "627bf8a5-1524-425f-93f8-28e217f2adec",
   "metadata": {},
   "source": [
    "Results and logs get put into unique timestamped directories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "00c84e2c-1709-4136-8612-87142026ac2e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/usr/bin/tree\n",
      "\u001b[01;34moutput/output_230911_164547\u001b[00m\n",
      "├── log.txt\n",
      "├── model_epoch=1.pt\n",
      "├── model_epoch=2.pt\n",
      "└── model_final_iteration=186.pt\n",
      "\n",
      "0 directories, 4 files\n"
     ]
    }
   ],
   "source": [
    "!which tree && tree output/* || true"
   ]
  },
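  {
   "cell_type": "markdown",
   "id": "a3c1e5f7-2b4d-4c6e-8f01-5d7b9c1e3a02",
   "metadata": {},
   "source": [
    "The `output_dir` expression in `train.yaml` is ordinary Python evaluated by the config parser. A minimal sketch of the same call outside the bundle, assuming `root_dir` is `.` as in `common.yaml` (the variable names here are for illustration only):\n",
    "\n",
    "```python\n",
    "import datetime\n",
    "\n",
    "root_dir = '.'  # same value as in common.yaml\n",
    "\n",
    "# strftime fills in the date and time fields, so each run gets its own directory name\n",
    "output_dir = datetime.datetime.now().strftime(root_dir + '/output/output_%y%m%d_%H%M%S')\n",
    "print(output_dir)\n",
    "```\n",
    "\n",
    "The name has one-second resolution, so two runs started within the same second would share a directory; for interactive use this is unlikely to matter."
   ]
  },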
  {
   "cell_type": "markdown",
   "id": "5705ff79-fe58-410a-bb93-80b4f3fa2ea2",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "We also need an inference script which applies a loaded network to every image in a given directory and writes the results to a file or to the log output. For segmentation networks this should save generated segmentations to known locations, but for this classification network we'll just save the predicted class of each image to a csv file.\n",
    "\n",
    "The first thing to do is create a test directory with only a few test images so we can demonstrate inference quickly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3a957503-39e4-4f73-a989-ce6e4e2d3e9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading dataset: 100%|██████████| 5895/5895 [00:01<00:00, 3696.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MedNIST/AbdomenCT/006070.jpeg Label: 0\n",
      "MedNIST/BreastMRI/006574.jpeg Label: 1\n",
      "MedNIST/ChestCT/009858.jpeg Label: 3\n",
      "MedNIST/CXR/007398.jpeg Label: 2\n",
      "MedNIST/Hand/005663.jpeg Label: 4\n",
      "MedNIST/HeadCT/006896.jpeg Label: 5\n",
      "MedNIST/HeadCT/007179.jpeg Label: 5\n",
      "MedNIST/CXR/001190.jpeg Label: 2\n",
      "MedNIST/ChestCT/005138.jpeg Label: 3\n",
      "MedNIST/BreastMRI/000023.jpeg Label: 1\n",
      "MedNIST/BreastMRI/006831.jpeg Label: 1\n",
      "MedNIST/BreastMRI/006561.jpeg Label: 1\n",
      "MedNIST/AbdomenCT/005808.jpeg Label: 0\n",
      "MedNIST/Hand/005363.jpeg Label: 4\n",
      "MedNIST/BreastMRI/001751.jpeg Label: 1\n",
      "MedNIST/CXR/002910.jpeg Label: 2\n",
      "MedNIST/ChestCT/003831.jpeg Label: 3\n",
      "MedNIST/HeadCT/006337.jpeg Label: 5\n",
      "MedNIST/AbdomenCT/005503.jpeg Label: 0\n",
      "MedNIST/Hand/003450.jpeg Label: 4\n"
     ]
    }
   ],
   "source": [
    "root_dir = \".\"  # assuming MedNIST was downloaded to the current directory\n",
    "num_images = 20\n",
    "dataset = MedNISTDataset(root_dir=root_dir, section=\"test\", download=False)\n",
    "\n",
    "!mkdir -p test_images\n",
    "\n",
    "for i in range(num_images):\n",
    "    filename = dataset[i][\"image\"].meta[\"filename_or_obj\"]\n",
    "    print(filename, \"Label:\", dataset[i][\"label\"])\n",
    "    !cp {root_dir}/{filename} test_images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0044efdc-6c5e-479c-880b-acd9e7ab4fea",
   "metadata": {
    "tags": []
   },
   "source": [
    "Next remove the existing example inference script:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7f800520-f29f-4b80-9af4-5e069f97824b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!rm \"MedNISTClassifier_v2/configs/inference.json\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef85014c-d1eb-4a93-911b-f405eac74094",
   "metadata": {},
   "source": [
    "Now we'll create the inference script, which will apply the network to every file in the given directory (thus assuming all are images) and save the results to a csv file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3c5556db-2e63-484c-9358-977b4c35d60f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing MedNISTClassifier_v2/configs/inference.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile MedNISTClassifier_v2/configs/inference.yaml\n",
    "\n",
    "imports:\n",
    "- $import glob\n",
    "\n",
    "input_dir: 'input'\n",
    "# dataset is a list of dictionaries to work with dictionary transforms\n",
    "input_files: '$[{@image: f} for f in sorted(glob.glob(@input_dir+''/*.*''))]'\n",
    "\n",
    "infer_dataset:\n",
    "  _target_: Dataset\n",
    "  data: '@input_files'\n",
    "  transform: \n",
    "    _target_: Compose\n",
    "    transforms: '@train_transforms'\n",
    "\n",
    "infer_dl:\n",
    "  _target_: DataLoader\n",
    "  dataset: '@infer_dataset'\n",
    "  batch_size: 1\n",
    "  shuffle: false\n",
    "  num_workers: 0\n",
    "\n",
    "# transforms applied to network output, same as those in training except \"label\" isn't present\n",
    "post_transform:\n",
    "  _target_: Compose\n",
    "  transforms:\n",
    "  - _target_: Activationsd\n",
    "    keys: '@pred'\n",
    "    softmax: true \n",
    "  - _target_: AsDiscreted\n",
    "    keys: ['@pred']\n",
    "    argmax: true \n",
    "\n",
    "# handlers to load the checkpoint file (and fail if a file isn't found), and save classification results to a csv file\n",
    "handlers:\n",
    "- _target_: CheckpointLoader\n",
    "  load_path: '@ckpt_path'\n",
    "  load_dict:\n",
    "    model: '@net'\n",
    "- _target_: ClassificationSaver\n",
    "  batch_transform: '$lambda batch: batch[0][@image].meta'\n",
    "  output_transform: '$monai.handlers.from_engine([''pred''])'\n",
    "\n",
    "inferer: \n",
    "  _target_: SimpleInferer\n",
    "\n",
    "evaluator:\n",
    "  _target_: SupervisedEvaluator\n",
    "  device: '@device'\n",
    "  val_data_loader: '@infer_dl'\n",
    "  network: '@net'\n",
    "  inferer: '@inferer'\n",
    "  postprocessing: '@post_transform'\n",
    "  val_handlers: '@handlers'\n",
    "\n",
    "inference:\n",
    "- '$@evaluator.run()'"
   ]
  },
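  {
   "cell_type": "markdown",
   "id": "a3c1e5f7-2b4d-4c6e-8f01-5d7b9c1e3a03",
   "metadata": {},
   "source": [
    "The `input_files` line in the config above is a plain list comprehension. The sketch below reproduces it outside the bundle on a temporary directory, assuming `@image` resolves to the string `image` (the value of `CommonKeys.IMAGE`); the file names are invented for the example.\n",
    "\n",
    "```python\n",
    "import glob\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "image_key = 'image'  # assumed value of monai.utils.CommonKeys.IMAGE\n",
    "\n",
    "with tempfile.TemporaryDirectory() as input_dir:\n",
    "    # create a few empty placeholder files standing in for images\n",
    "    for name in ['b.jpeg', 'a.jpeg', 'c.png']:\n",
    "        open(os.path.join(input_dir, name), 'w').close()\n",
    "\n",
    "    # same expression as in inference.yaml: one dictionary per file, sorted by path\n",
    "    input_files = [{image_key: f} for f in sorted(glob.glob(input_dir + '/*.*'))]\n",
    "    names = [os.path.basename(d[image_key]) for d in input_files]\n",
    "\n",
    "print(names)\n",
    "```\n",
    "\n",
    "The `*.*` pattern matches any file name containing a dot, so non-image files in the directory would be picked up as well, which is why the input directory should contain only images."
   ]
  },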
  {
   "cell_type": "markdown",
   "id": "5e9a706a-b135-4943-8245-0da8d5dad415",
   "metadata": {},
   "source": [
    "Inference can now be run, specifying the checkpoint file to load as being one from our training run and the input directory as \"test_images\" which was created above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "acdcc111-f259-4701-8b1d-31fcf74398bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-09-18 08:26:35,660 - INFO - --- input summary of monai.bundle.scripts.run ---\n",
      "2024-09-18 08:26:35,661 - INFO - > config_file: ['./MedNISTClassifier_v2/configs/common.yaml',\n",
      " './MedNISTClassifier_v2/configs/inference.yaml']\n",
      "2024-09-18 08:26:35,661 - INFO - > meta_file: './MedNISTClassifier_v2/configs/metadata.json'\n",
      "2024-09-18 08:26:35,661 - INFO - > logging_file: './MedNISTClassifier_v2/configs/logging.conf'\n",
      "2024-09-18 08:26:35,661 - INFO - > run_id: 'inference'\n",
      "2024-09-18 08:26:35,661 - INFO - > bundle_root: './MedNISTClassifier_v2'\n",
      "2024-09-18 08:26:35,661 - INFO - > ckpt_path: 'output/output_240918_082534/model_final_iteration=186.pt'\n",
      "2024-09-18 08:26:35,661 - INFO - > input_dir: 'test_images'\n",
      "2024-09-18 08:26:35,661 - INFO - ---\n",
      "\n",
      "\n",
      "2024-09-18 08:26:35,661 - INFO - Setting logging properties based on config: ./MedNISTClassifier_v2/configs/logging.conf.\n",
      "2024-09-18 08:26:35,899 - ignite.engine.engine.SupervisedEvaluator - INFO - Engine run resuming from iteration 0, epoch 0 until 1 epochs\n",
      "\n",
      "2024-09-18 08:26:35,958 - ignite.engine.engine.SupervisedEvaluator - INFO - Restored all variables from output/output_240918_082534/model_final_iteration=186.pt\n",
      "2024-09-18 08:26:36,545 - ignite.engine.engine.SupervisedEvaluator - INFO - Epoch[1] Complete. Time taken: 00:00:00.586\n",
      "2024-09-18 08:26:36,545 - ignite.engine.engine.SupervisedEvaluator - INFO - Engine run complete. Time taken: 00:00:00.646\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "BUNDLE=\"./MedNISTClassifier_v2\"\n",
    "# need to capture name since it'll be different for you\n",
    "ckpt=$(find output -name 'model_final_iteration=186.pt'|sort|tail -1)\n",
    "\n",
    "python -m monai.bundle run inference \\\n",
    "    --bundle_root \"$BUNDLE\" \\\n",
    "    --logging_file \"$BUNDLE/configs/logging.conf\" \\\n",
    "    --meta_file \"$BUNDLE/configs/metadata.json\" \\\n",
    "    --config_file \"['$BUNDLE/configs/common.yaml','$BUNDLE/configs/inference.yaml']\" \\\n",
    "    --ckpt_path \"$ckpt\" \\\n",
    "    --input_dir test_images "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "955faa08-0552-4bff-ba84-238e9a404f62",
   "metadata": {},
   "source": [
    "This will save the results of the inference to \"predictions.csv\" by default. You can change the output filename with an argument like `'--handlers#1#filename' pred.csv`, which directly sets the `filename` parameter of the second handler in the list (the `ClassificationSaver`). The single quotes around the argument name keep the shell from giving the hash sigil any special treatment, since `#` can start a comment in Bash.\n",
    "\n",
    "Looking at the output, the results aren't terribly legible:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4a695039-7a53-4f9a-9754-769a9f8ebac8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_images/000023.jpeg,1.0\n",
      "test_images/001190.jpeg,2.0\n",
      "test_images/001751.jpeg,1.0\n",
      "test_images/002910.jpeg,2.0\n",
      "test_images/003450.jpeg,4.0\n",
      "test_images/003831.jpeg,3.0\n",
      "test_images/005138.jpeg,3.0\n",
      "test_images/005363.jpeg,4.0\n",
      "test_images/005503.jpeg,0.0\n",
      "test_images/005663.jpeg,4.0\n",
      "test_images/005808.jpeg,0.0\n",
      "test_images/006070.jpeg,0.0\n",
      "test_images/006337.jpeg,5.0\n",
      "test_images/006561.jpeg,1.0\n",
      "test_images/006574.jpeg,1.0\n",
      "test_images/006831.jpeg,1.0\n",
      "test_images/006896.jpeg,5.0\n",
      "test_images/007179.jpeg,5.0\n",
      "test_images/007398.jpeg,2.0\n",
      "test_images/009858.jpeg,3.0\n"
     ]
    }
   ],
   "source": [
    "!cat predictions.csv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a231c937-9ced-4a6d-b01c-3bc9a128fd62",
   "metadata": {},
   "source": [
    "The second column is the predicted class, which we can use as an index into our list of class names to get something more readable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1065f928-3f66-47af-aed4-be2f0443cf2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_images/000023.jpeg BreastMRI\n",
      "test_images/001190.jpeg CXR\n",
      "test_images/001751.jpeg BreastMRI\n",
      "test_images/002910.jpeg CXR\n",
      "test_images/003450.jpeg Hand\n",
      "test_images/003831.jpeg ChestCT\n",
      "test_images/005138.jpeg ChestCT\n",
      "test_images/005363.jpeg Hand\n",
      "test_images/005503.jpeg AbdomenCT\n",
      "test_images/005663.jpeg Hand\n",
      "test_images/005808.jpeg AbdomenCT\n",
      "test_images/006070.jpeg AbdomenCT\n",
      "test_images/006337.jpeg HeadCT\n",
      "test_images/006561.jpeg BreastMRI\n",
      "test_images/006574.jpeg BreastMRI\n",
      "test_images/006831.jpeg BreastMRI\n",
      "test_images/006896.jpeg HeadCT\n",
      "test_images/007179.jpeg HeadCT\n",
      "test_images/007398.jpeg CXR\n",
      "test_images/009858.jpeg ChestCT\n"
     ]
    }
   ],
   "source": [
    "class_names = [\"AbdomenCT\", \"BreastMRI\", \"CXR\", \"ChestCT\", \"Hand\", \"HeadCT\"]\n",
    "\n",
    "for fn, idx in np.loadtxt(\"predictions.csv\", delimiter=\",\", dtype=str):\n",
    "    print(fn, class_names[int(float(idx))])"
   ]
  },
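  {
   "cell_type": "markdown",
   "id": "a3c1e5f7-2b4d-4c6e-8f01-5d7b9c1e3a04",
   "metadata": {},
   "source": [
    "The `'--handlers#1#filename'` override mentioned earlier addresses a value inside the nested config: `handlers` is a list, `1` is its second entry (the `ClassificationSaver`), and `filename` is a parameter of that entry. The function below is a toy illustration of that addressing scheme on plain Python dictionaries and lists; `set_by_path` is a hypothetical helper, not MONAI's implementation.\n",
    "\n",
    "```python\n",
    "def set_by_path(config, path, value):\n",
    "    # walk the nested structure using '#'-separated keys; keys index lists as integers\n",
    "    keys = path.split('#')\n",
    "    node = config\n",
    "    for key in keys[:-1]:\n",
    "        node = node[int(key)] if isinstance(node, list) else node[key]\n",
    "    last = keys[-1]\n",
    "    if isinstance(node, list):\n",
    "        node[int(last)] = value\n",
    "    else:\n",
    "        node[last] = value\n",
    "\n",
    "\n",
    "# cut-down stand-in for the handlers section of inference.yaml\n",
    "config = {\n",
    "    'handlers': [\n",
    "        {'_target_': 'CheckpointLoader'},\n",
    "        {'_target_': 'ClassificationSaver'},\n",
    "    ]\n",
    "}\n",
    "set_by_path(config, 'handlers#1#filename', 'pred.csv')\n",
    "print(config['handlers'][1])\n",
    "```\n",
    "\n",
    "Overriding by position is brittle if the order of handlers changes, so it is worth checking the index against the config file before relying on it."
   ]
  },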
  {
   "cell_type": "markdown",
   "id": "235e90b9-9209-4a58-885d-042ab55c9c18",
   "metadata": {},
   "source": [
    "## Putting the Bundle Together\n",
    "\n",
    "We now have a checkpoint for our network which produces good results, so we can make it the \"official\" shared weights for the bundle. We need to copy the checkpoint into the `models` directory and optionally produce a TorchScript version of the network.\n",
    "\n",
    "For the TorchScript conversion MONAI provides the `ckpt_export` program in the `monai.bundle` submodule:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c6672caa-fd51-4dde-a31d-5c4de8c3cc1d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-09-18 08:26:46,705 - INFO - --- input summary of monai.bundle.scripts.ckpt_export ---\n",
      "2024-09-18 08:26:46,706 - INFO - > net_id: 'network_def'\n",
      "2024-09-18 08:26:46,706 - INFO - > filepath: './MedNISTClassifier_v2/models/model.ts'\n",
      "2024-09-18 08:26:46,706 - INFO - > meta_file: './MedNISTClassifier_v2/configs/metadata.json'\n",
      "2024-09-18 08:26:46,706 - INFO - > config_file: ['./MedNISTClassifier_v2/configs/common.yaml',\n",
      " './MedNISTClassifier_v2/configs/inference.yaml']\n",
      "2024-09-18 08:26:46,706 - INFO - > ckpt_file: './MedNISTClassifier_v2/models/model.pt'\n",
      "2024-09-18 08:26:46,706 - INFO - > key_in_ckpt: 'model'\n",
      "2024-09-18 08:26:46,706 - INFO - > bundle_root: './MedNISTClassifier_v2'\n",
      "2024-09-18 08:26:46,706 - INFO - ---\n",
      "\n",
      "\n",
      "2023-09-11 16:57:12,519 - INFO - exported to file: ./MedNISTClassifier_v2/models/model.ts.\n",
      "/usr/bin/tree\n",
      "\u001b[01;34m./MedNISTClassifier_v2\u001b[00m\n",
      "├── \u001b[01;34mconfigs\u001b[00m\n",
      "│   ├── common.yaml\n",
      "│   ├── inference.yaml\n",
      "│   ├── logging.conf\n",
      "│   ├── metadata.json\n",
      "│   └── train.yaml\n",
      "├── \u001b[01;34mdocs\u001b[00m\n",
      "│   └── README.md\n",
      "├── LICENSE\n",
      "└── \u001b[01;34mmodels\u001b[00m\n",
      "    ├── model.pt\n",
      "    └── model.ts\n",
      "\n",
      "3 directories, 9 files\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-09-18 08:26:47,818 - INFO - exported to file: ./MedNISTClassifier_v2/models/model.ts.\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "BUNDLE=\"./MedNISTClassifier_v2\"\n",
    "\n",
    "ckpt=$(find output -name 'model_final_iteration=186.pt'|sort|tail -1)\n",
    "cp \"$ckpt\" \"$BUNDLE/models/model.pt\"\n",
    "\n",
    "python -m monai.bundle ckpt_export \\\n",
    "    --bundle_root \"$BUNDLE\" \\\n",
    "    --meta_file \"$BUNDLE/configs/metadata.json\" \\\n",
    "    --config_file \"['$BUNDLE/configs/common.yaml','$BUNDLE/configs/inference.yaml']\" \\\n",
    "    --net_id network_def \\\n",
    "    --key_in_ckpt model \\\n",
    "    --ckpt_file \"$BUNDLE/models/model.pt\" \\\n",
    "    --filepath \"$BUNDLE/models/model.ts\" \n",
    "\n",
    "which tree && tree \"$BUNDLE\" || true"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8def15f8-d0dc-4ed0-8bf7-669e0720ac81",
   "metadata": {},
   "source": [
    "This will have produced the `model.ts` file in `models`, as shown above, which can be loaded in Python without the bundle config scripts like any other TorchScript object.\n",
    "\n",
    "The arguments for the `ckpt_export` command specify the components to use in the config files and the checkpoint:\n",
    "* `bundle_root`, `meta_file`, and `config_file` are as in previous usages.\n",
    "* `net_id` specifies the object in the config files which represents the network definition, i.e. the instantiated network object.\n",
    "* `key_in_ckpt` names the key under which the model weights are found in the checkpoint. This assumes the checkpoint is a dictionary, which is what `CheckpointSaver` produces; if the file isn't a dictionary, omit this argument.\n",
    "* `ckpt_file` is the name of the checkpoint file itself.\n",
    "* `filepath` is the output filename to store the TorchScript object to."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18a62139-8a21-4bb9-96d4-e86d61298c40",
   "metadata": {},
   "source": [
    "## Summary and Next Steps\n",
    "\n",
    "This tutorial has covered MONAI Bundle best practices:\n",
    "  * Separate common definition config files which are combined with application-specific files\n",
    "  * Separating out definitions in config files for easier reading and changes\n",
    "  * Using Engine-based classes for training and validation\n",
    "  * Simple training run management with uniquely-created results directories\n",
    "  * Inference script to generate a results csv file containing predictions\n",
    "  \n",
    "The next tutorial will discuss creating bundles to wrap pre-existing PyTorch code so that you can get code into the bundle ecosystem without rewriting the world."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
